{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# snnTorch - Tutorial 3\n",
    "### By Jason K. Eshraghian"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient-based Learning in Convolutional Spiking Neural Networks\n",
    "In this tutorial, we'll use a convolutional spiking neural network to classify the MNIST dataset.\n",
    "We will use the backpropagation through time (BPTT) algorithm to do so. This tutorial is largely the same as Tutorial 2, just with a different network architecture to show how to integrate convolutions with snnTorch.\n",
    "\n",
    "If running in Google Colab:\n",
    "* Ensure you are connected to a GPU by checking Runtime > Change runtime type > Hardware accelerator: GPU\n",
    "* Next, install the Test PyPI distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Install the test PyPi Distribution of snntorch\n",
    "!pip install -i https://test.pypi.org/simple/ snntorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 1. Setting up the Static MNIST Dataset\n",
    "### 1.1. Import packages and setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import snntorch as snn\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "import numpy as np\n",
    "import itertools\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 Define network and SNN parameters\n",
    "We will use a 2conv-2MaxPool-FCN architecture for a sequence of 25 time steps.\n",
    "\n",
    "* `alpha` is the decay rate of the synaptic current of a neuron\n",
    "* `beta` is the decay rate of the membrane potential of a neuron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Network Architecture\n",
    "num_inputs = 28*28\n",
    "num_outputs = 10\n",
    "\n",
    "# Training Parameters\n",
    "batch_size=128\n",
    "data_path='/data/mnist'\n",
    "\n",
    "# Temporal Dynamics\n",
    "num_steps = 25\n",
    "time_step = 1e-3\n",
    "tau_mem = 6.5e-4\n",
    "tau_syn = 5.5e-4\n",
    "alpha = float(np.exp(-time_step/tau_syn))\n",
    "beta = float(np.exp(-time_step/tau_mem))\n",
    "\n",
    "dtype = torch.float\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
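  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sanity check on the decay rates, the sketch below recomputes `alpha` and `beta` from the time constants above using only the standard library. The helper name `decay_rate` is hypothetical and introduced here for illustration; it is not part of snnTorch.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def decay_rate(time_step, tau):\n",
    "    # Fraction of the state that survives one time step of exponential decay\n",
    "    return math.exp(-time_step / tau)\n",
    "\n",
    "time_step = 1e-3\n",
    "alpha = decay_rate(time_step, tau=5.5e-4)  # synaptic current decay\n",
    "beta = decay_rate(time_step, tau=6.5e-4)   # membrane potential decay\n",
    "\n",
    "print(f'alpha = {alpha:.3f}, beta = {beta:.3f}')\n",
    "```\n",
    "\n",
    "Because `tau_mem` is larger than `tau_syn`, `beta` comes out larger than `alpha`, so the membrane potential decays more slowly than the synaptic current."
   ]
  },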
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 1.3 Download MNIST Dataset\n",
    "To see how to construct a validation set, refer to Tutorial 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Define a transform\n",
    "transform = transforms.Compose([\n",
    "            transforms.Resize((28, 28)),\n",
    "            transforms.Grayscale(),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0,), (1,))])\n",
    "\n",
    "mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)\n",
    "mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 1.4 Create DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)\n",
    "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 2. Define Network\n",
    "snnTorch contains a series of neuron models and related functions to ease the training process.\n",
    "Neurons are treated as activations with recurrent connections, and integrate smoothly with PyTorch's pre-existing layer functions.\n",
    "* `snntorch.Stein` is a simple Leaky Integrate and Fire (LIF) neuron. Specifically, it uses Stein's model which assumes instantaneous rise times for synaptic current and membrane potential.\n",
    "* `snntorch.FastSigmoidSurrogate` defines separate forward and backward functions. The forward function is a Heaviside step function for spike generation. The backward function is the derivative of a fast sigmoid function, to ensure continuous differentiability.\n",
    "FSS is mostly derived from:\n",
    "\n",
    ">Neftci, E. O., Mostafa, H., and Zenke, F. (2019) Surrogate Gradient Learning in Spiking Neural Networks. https://arxiv.org/abs/1901.09948\n",
    "\n",
    "There are a few other surrogate gradient functions included.\n",
    "`snn.slope` is a variable that defines the slope of the backward surrogate.\n",
    "TO-DO: Include visualisation."
   ]
  },
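  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the forward/backward split concrete, here is a minimal pure-Python sketch for a single scalar input. It assumes the fast sigmoid $f(x) = x / (1 + k|x|)$ with slope $k$, whose derivative is $1 / (1 + k|x|)^2$. The function names are hypothetical, and the actual `snn.FastSigmoidSurrogate` implementation may differ in detail.\n",
    "\n",
    "```python\n",
    "def spike_forward(mem_minus_threshold):\n",
    "    # Forward pass: Heaviside step, emit a spike when the threshold is exceeded\n",
    "    return 1.0 if mem_minus_threshold > 0 else 0.0\n",
    "\n",
    "def spike_backward(mem_minus_threshold, slope=50.0):\n",
    "    # Backward pass: derivative of the fast sigmoid, used in place of the\n",
    "    # Heaviside derivative (which is zero almost everywhere)\n",
    "    return 1.0 / (1.0 + slope * abs(mem_minus_threshold)) ** 2\n",
    "\n",
    "for x in (-0.5, 0.0, 0.5):\n",
    "    print(x, spike_forward(x), spike_backward(x))\n",
    "```\n",
    "\n",
    "A larger slope concentrates the surrogate gradient more tightly around the threshold."
   ]
  },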
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "spike_grad = snn.FastSigmoidSurrogate.apply\n",
    "snn.slope = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Now we can define our spiking neural network (SNN).\n",
    "If you have already worked through Tutorial 2, you may wish to skip ahead.\n",
    "\n",
    "Creating an instance of the `Stein` neuron requires two compulsory arguments and two optional arguments:\n",
    "1. $I_{syn}$ decay rate, $\\alpha$,\n",
    "2. $V_{mem}$ decay rate, $\\beta$,\n",
    "3. *optional*: the surrogate spiking function, `spike_grad` (*default*: the gradient of the Heaviside function), and\n",
    "4. *optional*: the threshold for spiking, (*default*: 1.0).\n",
    "\n",
    "snnTorch treats the LIF neuron as a recurrent activation. Therefore, it requires initialization of its internal states.\n",
    "For each layer, we initialize the synaptic current (`syn1`, `syn2`, `syn3`), the membrane potential (`mem1`, `mem2`, `mem3`), and the post-synaptic spikes (`spk1`, `spk2`, `spk3`) to zero.\n",
    "The `init_stein` method of each neuron will take care of this.\n",
    "\n",
    "For rate coding, the final layer of spikes and membrane potential are used to determine accuracy and loss, respectively.\n",
    "So their historical values are recorded in `spk3_rec` and `mem3_rec`.\n",
    "\n",
    "Keep in mind, the dataset we are using is just static MNIST. I.e., it is *not* time-varying.\n",
    "Therefore, we pass the same MNIST sample to the input at each time step.\n",
    "This is handled with `cur1 = F.max_pool2d(self.conv1(x), 2)`, where `x` is the same input over the whole for-loop."
   ]
  },
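  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recurrent state update can be illustrated with a single neuron in plain Python. This is a sketch under stated assumptions, not snnTorch's implementation: it assumes the synaptic current and membrane potential decay by `alpha` and `beta` per step, that a spike is emitted when the membrane potential exceeds the threshold, and that the threshold is subtracted from the membrane potential after a spike. The names `lif_step` and `run` are hypothetical.\n",
    "\n",
    "```python\n",
    "def lif_step(cur, syn, mem, alpha, beta, threshold=1.0):\n",
    "    # Decay the synaptic current and add the new input\n",
    "    syn = alpha * syn + cur\n",
    "    # Decay the membrane potential and integrate the synaptic current\n",
    "    mem = beta * mem + syn\n",
    "    # Emit a spike if the membrane potential exceeds the threshold\n",
    "    spk = 1.0 if mem > threshold else 0.0\n",
    "    # Assumed reset-by-subtraction after a spike\n",
    "    mem = mem - spk * threshold\n",
    "    return spk, syn, mem\n",
    "\n",
    "def run(cur, num_steps=25, alpha=0.16, beta=0.21):\n",
    "    # States start at zero, mirroring the initialization described above\n",
    "    spk, syn, mem = 0.0, 0.0, 0.0\n",
    "    spikes = []\n",
    "    for _ in range(num_steps):\n",
    "        # The same input is presented at every time step (static input)\n",
    "        spk, syn, mem = lif_step(cur, syn, mem, alpha, beta)\n",
    "        spikes.append(spk)\n",
    "    return spikes\n",
    "\n",
    "print(sum(run(cur=1.0)), 'spikes for a strong input')\n",
    "print(sum(run(cur=0.0)), 'spikes for zero input')\n",
    "```\n",
    "\n",
    "Presenting the same `cur` at every step plays the role of passing the same static MNIST image into the network on each iteration of the for-loop."
   ]
  },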
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # initialize layers\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=5, stride=1, padding=1)\n",
    "        self.lif1 = snn.Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)\n",
    "        self.conv2 = nn.Conv2d(in_channels=12, out_channels=64, kernel_size=5, stride=1, padding=1)\n",
    "        self.lif2 = snn.Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)\n",
    "        self.fc2 = nn.Linear(64*5*5, 10)\n",
    "        self.lif3 = snn.Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Initialize LIF state variables and spike output tensors\n",
    "        spk1, syn1, mem1 = self.lif1.init_stein(batch_size, 12, 13, 13)\n",
    "        spk2, syn2, mem2 = self.lif2.init_stein(batch_size, 64, 5, 5)\n",
    "        spk3, syn3, mem3 = self.lif3.init_stein(batch_size, 10)\n",
    "\n",
    "        spk3_rec = []\n",
    "        mem3_rec = []\n",
    "\n",
    "        for step in range(num_steps):\n",
    "            cur1 = F.max_pool2d(self.conv1(x), 2)\n",
    "            spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)\n",
    "            cur2 = F.max_pool2d(self.conv2(spk1), 2)\n",
    "            spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)\n",
    "            cur3 = self.fc2(spk2.view(batch_size, -1))\n",
    "            spk3, syn3, mem3 = self.lif3(cur3, syn3, mem3)\n",
    "\n",
    "            spk3_rec.append(spk3)\n",
    "            mem3_rec.append(mem3)\n",
    "\n",
    "        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)\n",
    "\n",
    "net = Net().to(device)"
   ]
  },
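  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sizes passed to the state initialization (`12, 13, 13` and `64, 5, 5`) and to `nn.Linear(64*5*5, 10)` follow from standard convolution and pooling arithmetic. The helper names below are hypothetical and only illustrate that arithmetic:\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel_size, stride=1, padding=0):\n",
    "    # Output width/height of a convolution on a square input\n",
    "    return (size + 2 * padding - kernel_size) // stride + 1\n",
    "\n",
    "def pool_out(size, kernel_size):\n",
    "    # Output width/height of max-pooling with stride equal to the kernel size\n",
    "    return size // kernel_size\n",
    "\n",
    "size = 28                                          # MNIST image width/height\n",
    "size = pool_out(conv_out(size, 5, padding=1), 2)   # conv1 followed by 2x2 max-pool\n",
    "layer1 = size\n",
    "print('after conv1 + pool:', layer1)\n",
    "size = pool_out(conv_out(size, 5, padding=1), 2)   # conv2 followed by 2x2 max-pool\n",
    "layer2 = size\n",
    "print('after conv2 + pool:', layer2)\n",
    "flat_features = 64 * layer2 * layer2               # input features of the linear layer\n",
    "print('fc input features:', flat_features)\n",
    "```\n",
    "\n",
    "If the kernel size, padding, or number of channels is changed, the state shapes and the input size of the linear layer need to be updated to match."
   ]
  },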
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 3. Training\n",
    "Time for training! Let's first define a couple of functions to print out test/train accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def print_batch_accuracy(data, targets, train=False):\n",
    "    output, _ = net(data.view(batch_size, 1, 28, 28))\n",
    "    _, idx = output.sum(dim=0).max(1)\n",
    "    acc = np.mean((targets == idx).detach().cpu().numpy())\n",
    "\n",
    "    if train:\n",
    "        print(f\"Train Set Accuracy: {acc}\")\n",
    "    else:\n",
    "        print(f\"Test Set Accuracy: {acc}\")\n",
    "\n",
    "def train_printer():\n",
    "    print(f\"Epoch {epoch}, Minibatch {minibatch_counter}\")\n",
    "    print(f\"Train Set Loss: {loss_hist[counter]}\")\n",
    "    print(f\"Test Set Loss: {test_loss_hist[counter]}\")\n",
    "    print_batch_accuracy(data_it, targets_it, train=True)\n",
    "    print_batch_accuracy(testdata_it, testtargets_it, train=False)\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 3.1 Optimizer & Loss\n",
    "* *Output Activation*: We'll apply the softmax function to the membrane potentials of the output layer, rather than the spikes.\n",
    "* *Loss*: This will then be used to calculate the negative log-likelihood loss.\n",
    "By encouraging the membrane of the correct neuron class to reach the threshold, we expect that neuron will fire more frequently.\n",
    "The loss could be applied to the spike count as well, but the membrane potential is continuous whereas the spike count is discrete.\n",
    "* *Optimizer*: The Adam optimizer is used for weight updates.\n",
    "* *Accuracy*: Accuracy is measured by counting the spikes of the output neurons. The neuron that fires the most frequently will be our predicted class."
   ]
  },
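  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following plain-Python sketch mirrors this scheme for one sample with three output neurons and two time steps. The membrane potentials, spike trains, and helper names are made up for illustration; in the notebook, `nn.LogSoftmax` and `nn.NLLLoss` do this work on whole batches.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def log_softmax(values):\n",
    "    # Numerically stable log-softmax over a list of floats\n",
    "    peak = max(values)\n",
    "    log_sum = peak + math.log(sum(math.exp(v - peak) for v in values))\n",
    "    return [v - log_sum for v in values]\n",
    "\n",
    "# Hypothetical output-layer recordings, indexed as [time step][neuron]\n",
    "mem_rec = [[0.2, 0.9, 0.1], [0.3, 1.2, 0.0]]\n",
    "spk_rec = [[0, 0, 0], [0, 1, 0]]\n",
    "target = 1\n",
    "\n",
    "# Loss: negative log-likelihood of the target class, summed over time steps\n",
    "loss = sum(-log_softmax(mem)[target] for mem in mem_rec)\n",
    "\n",
    "# Accuracy: the predicted class is the neuron with the highest spike count\n",
    "spike_counts = [sum(step[n] for step in spk_rec) for n in range(3)]\n",
    "predicted = spike_counts.index(max(spike_counts))\n",
    "\n",
    "print(f'loss = {loss:.3f}, predicted class = {predicted}')\n",
    "```\n",
    "\n",
    "Note that the loss depends only on the membrane potentials, while the prediction depends only on the spikes."
   ]
  },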
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.999))\n",
    "log_softmax_fn = nn.LogSoftmax(dim=-1)\n",
    "loss_fn = nn.NLLLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 3.2 Training Loop\n",
    "Now just sit back, relax, and wait for convergence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "loss_hist = []\n",
    "test_loss_hist = []\n",
    "counter = 0\n",
    "\n",
    "# Outer training loop\n",
    "for epoch in range(5):\n",
    "\n",
    "    minibatch_counter = 0\n",
    "    train_batch = iter(train_loader)\n",
    "\n",
    "    # Minibatch training loop\n",
    "    for data_it, targets_it in train_batch:\n",
    "        data_it = data_it.to(device)\n",
    "        targets_it = targets_it.to(device)\n",
    "\n",
    "        output, mem_rec = net(data_it.view(batch_size, 1, 28, 28)) # [28x28] or [1x28x28]?\n",
    "        log_p_y = log_softmax_fn(mem_rec)\n",
    "        loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
    "\n",
    "        # Sum loss over time steps to perform BPTT\n",
    "        for step in range(num_steps):\n",
    "          loss_val += loss_fn(log_p_y[step], targets_it)\n",
    "\n",
    "        # Gradient calculation\n",
    "        optimizer.zero_grad()\n",
    "        loss_val.backward(retain_graph=True)\n",
    "\n",
    "        # Weight Update\n",
    "        nn.utils.clip_grad_norm_(net.parameters(), 1)\n",
    "        optimizer.step()\n",
    "\n",
    "        # Store loss history for future plotting\n",
    "        loss_hist.append(loss_val.item())\n",
    "\n",
    "        # Test set\n",
    "        test_data = itertools.cycle(test_loader)\n",
    "        testdata_it, testtargets_it = next(test_data)\n",
    "        testdata_it = testdata_it.to(device)\n",
    "        testtargets_it = testtargets_it.to(device)\n",
    "\n",
    "        # Test set forward pass\n",
    "        test_output, test_mem_rec = net(testdata_it.view(batch_size, 1, 28, 28))\n",
    "\n",
    "        # Test set loss\n",
    "        log_p_ytest = log_softmax_fn(test_mem_rec)\n",
    "        log_p_ytest = log_p_ytest.sum(dim=0)\n",
    "        loss_val_test = loss_fn(log_p_ytest, testtargets_it)\n",
    "        test_loss_hist.append(loss_val_test.item())\n",
    "\n",
    "        # Print test/train loss/accuracy\n",
    "        if counter % 50 == 0:\n",
    "          train_printer()\n",
    "        minibatch_counter += 1\n",
    "        counter += 1\n",
    "\n",
    "loss_hist_true_grad = loss_hist\n",
    "test_loss_hist_true_grad = test_loss_hist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 4. Results\n",
    "### 4.1 Plot Training/Test Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Plot Loss\n",
    "fig = plt.figure(facecolor=\"w\", figsize=(10, 5))\n",
    "plt.plot(loss_hist)\n",
    "plt.plot(test_loss_hist)\n",
    "plt.legend([\"Train Loss\", \"Test Loss\"])\n",
    "plt.xlabel(\"Minibatch\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 4.2 Test Set Accuracy\n",
    "This cell iterates over all minibatches to measure accuracy over the full 10,000 samples in the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "total = 0\n",
    "correct = 0\n",
    "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)\n",
    "\n",
    "with torch.no_grad():\n",
    "  net.eval()\n",
    "  for data in test_loader:\n",
    "    images, labels = data\n",
    "    images = images.to(device)\n",
    "    labels = labels.to(device)\n",
    "\n",
    "    # If current batch matches batch_size, just do the usual thing\n",
    "    if images.size()[0] == batch_size:\n",
    "      outputs, _ = net(images.view(batch_size, 1, 28, 28))\n",
    "\n",
    "    # If current batch does not match batch_size (e.g., is the final minibatch),\n",
    "    # modify batch_size in a temp variable and restore it at the end of the else block\n",
    "    else:\n",
    "      temp_bs = batch_size\n",
    "      batch_size = images.size()[0]\n",
    "      outputs, _ = net(images.view(images.size()[0], 1, 28, 28))\n",
    "      batch_size = temp_bs\n",
    "\n",
    "    _, predicted = outputs.sum(dim=0).max(1)\n",
    "    total += labels.size(0)\n",
    "    correct += (predicted == labels).sum().item()\n",
    "\n",
    "print(f\"Total correctly classified test set images: {correct}/{total}\")\n",
    "print(f\"Test Set Accuracy: {100 * correct / total}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "That's it for static MNIST!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 5. Spiking MNIST\n",
    "As before, there isn't anything all that impressive about training a network on the static MNIST dataset.\n",
    "So let's apply rate-coding to convert it into a time-varying stream of spikes."
   ]
  },
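  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of rate coding, the sketch below treats each normalized pixel intensity as the probability of emitting a spike at every time step (a Bernoulli trial). This is an assumption about the general technique, not a description of how `spikegen.rate` is implemented; the function name `rate_encode` is hypothetical.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def rate_encode(pixels, num_steps, seed=0):\n",
    "    # pixels: list of intensities in [0, 1]\n",
    "    # returns: num_steps lists of 0/1 spikes, one entry per pixel\n",
    "    rng = random.Random(seed)\n",
    "    return [[1 if rng.random() < p else 0 for p in pixels]\n",
    "            for _ in range(num_steps)]\n",
    "\n",
    "pixels = [0.0, 0.5, 1.0]\n",
    "spikes = rate_encode(pixels, num_steps=25)\n",
    "\n",
    "# The firing rate of each pixel approximates its original intensity\n",
    "rates = [sum(step[i] for step in spikes) / len(spikes) for i in range(len(pixels))]\n",
    "print(rates)\n",
    "```\n",
    "\n",
    "The encoded data gains a leading time dimension, which is why the network in the next section indexes its input with `x[step]`."
   ]
  },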
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from snntorch import spikegen\n",
    "\n",
    "# MNIST to spiking-MNIST\n",
    "spike_data, spike_targets = spikegen.rate(data_it, targets_it, num_outputs=num_outputs, num_steps=num_steps,\n",
    "                                                      gain=1, offset=0, convert_targets=False, temporal_targets=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 5.1 Visualiser\n",
    "Let's animate one sample to confirm that the input really is a stream of spikes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "!pip install celluloid # matplotlib animations made easy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: if you are running the notebook locally on your desktop, please uncomment the line below and modify the path to your ffmpeg.exe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from celluloid import Camera\n",
    "from IPython.display import HTML\n",
    "\n",
    "# Animator\n",
    "spike_data_sample = spike_data[:, 0, 0].cpu()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "camera = Camera(fig)\n",
    "plt.axis('off')\n",
    "\n",
    "# plt.rcParams['animation.ffmpeg_path'] = 'C:\\\\path\\\\to\\\\your\\\\ffmpeg.exe'\n",
    "\n",
    "for step in range(num_steps):\n",
    "    im = ax.imshow(spike_data_sample[step, :, :], cmap='plasma')\n",
    "    camera.snap()\n",
    "\n",
    "# interval=40 specifies 40ms delay between frames\n",
    "a = camera.animate(interval=40)\n",
    "HTML(a.to_html5_video())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "print(spike_targets[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 6. Define Network\n",
    "The convolutional network is the same as before. The one difference is that the for-loop iterates through the first dimension of the input:\n",
    "`cur1 = F.max_pool2d(self.conv1(x[step]), 2)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "spike_grad = snn.FastSigmoidSurrogate.apply\n",
    "snn.slope = 50\n",
    "\n",
    "# Define a different network\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # initialize layers\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=5, stride=1, padding=1)\n",
    "        self.lif1 = snn.Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)\n",
    "        self.conv2 = nn.Conv2d(in_channels=12, out_channels=64, kernel_size=5, stride=1, padding=1)\n",
    "        self.lif2 = snn.Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)\n",
    "        self.fc2 = nn.Linear(64*5*5, 10)\n",
    "        self.lif3 = snn.Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)\n",
    "\n",
    "        # self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, stride=1, padding=0)\n",
    "        # self.lif1 = LIF(spike_fn=spike_fn, alpha=alpha, beta=beta)\n",
    "        # self.fc1 = nn.Linear(26*26*3, 10)\n",
    "        # self.lif2 = LIF(spike_fn=spike_fn, alpha=alpha, beta=beta)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Initialize LIF state variables and spike output tensors\n",
    "        spk1, syn1, mem1 = self.lif1.init_stein(batch_size, 12, 13, 13)\n",
    "        spk2, syn2, mem2 = self.lif2.init_stein(batch_size, 64, 5, 5)\n",
    "        spk3, syn3, mem3 = self.lif3.init_stein(batch_size, 10)\n",
    "\n",
    "        spk3_rec = []\n",
    "        mem3_rec = []\n",
    "\n",
    "        for step in range(num_steps):\n",
    "            cur1 = F.max_pool2d(self.conv1(x[step]), 2) # add max-pooling to membrane or spikes?\n",
    "            spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)\n",
    "            cur2 = F.max_pool2d(self.conv2(spk1), 2)\n",
    "            spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)\n",
    "            cur3 = self.fc2(spk2.view(batch_size, -1))\n",
    "            spk3, syn3, mem3 = self.lif3(cur3, syn3, mem3)\n",
    "\n",
    "            spk3_rec.append(spk3)\n",
    "            mem3_rec.append(mem3)\n",
    "\n",
    "        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)\n",
    "\n",
    "net = Net().to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 7. Training\n",
    "We make a slight modification to our print-out functions to handle the new first dimension of the input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def print_batch_accuracy(data, targets, train=False):\n",
    "    output, _ = net(data.view(num_steps, batch_size, 1, 28, 28))\n",
    "    _, idx = output.sum(dim=0).max(1)\n",
    "    acc = np.mean((targets == idx).detach().cpu().numpy())\n",
    "\n",
    "    if train:\n",
    "        print(f\"Train Set Accuracy: {acc}\")\n",
    "    else:\n",
    "        print(f\"Test Set Accuracy: {acc}\")\n",
    "\n",
    "def train_printer():\n",
    "    print(f\"Epoch {epoch}, Minibatch {minibatch_counter}\")\n",
    "    print(f\"Train Set Loss: {loss_hist[counter]}\")\n",
    "    print(f\"Test Set Loss: {test_loss_hist[counter]}\")\n",
    "    print_batch_accuracy(spike_data, spike_targets, train=True)\n",
    "    print_batch_accuracy(test_spike_data, test_spike_targets, train=False)\n",
    "    print(\"\\n\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 7.1 Optimizer & Loss\n",
    "We'll keep our optimizer and loss exactly the same as in the static MNIST case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.999))\n",
    "log_softmax_fn = nn.LogSoftmax(dim=-1)\n",
    "loss_fn = nn.NLLLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 7.2 Training Loop\n",
    "The training loop is identical to the static MNIST case, but we pass each minibatch through `spikegen.rate` before running it through the feedforward network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "loss_hist = []\n",
    "test_loss_hist = []\n",
    "counter = 0\n",
    "\n",
    "# Outer training loop\n",
    "for epoch in range(5):\n",
    "    minibatch_counter = 0\n",
    "    data = iter(train_loader)\n",
    "\n",
    "    # Minibatch training loop\n",
    "    for data_it, targets_it in data:\n",
    "        data_it = data_it.to(device)\n",
    "        targets_it = targets_it.to(device)\n",
    "\n",
    "        # Spike generator\n",
    "        spike_data, spike_targets = spikegen.rate(data_it, targets_it, num_outputs=num_outputs, num_steps=num_steps,\n",
    "                                                  gain=1, offset=0, convert_targets=False, temporal_targets=False)\n",
    "\n",
    "        # Forward pass\n",
    "        output, mem_rec = net(spike_data.view(num_steps, batch_size, 1, 28, 28))\n",
    "        log_p_y = log_softmax_fn(mem_rec)\n",
    "        loss_val = torch.zeros(1, dtype=dtype, device=device)\n",
    "\n",
    "        # Sum loss over time steps to perform BPTT\n",
    "        for step in range(num_steps):\n",
    "          loss_val += loss_fn(log_p_y[step], targets_it)\n",
    "\n",
    "        # Gradient Calculation\n",
    "        optimizer.zero_grad()\n",
    "        loss_val.backward(retain_graph=True)\n",
    "        nn.utils.clip_grad_norm_(net.parameters(), 1)\n",
    "\n",
    "        # Weight Update\n",
    "        optimizer.step()\n",
    "\n",
    "        # Store Loss history\n",
    "        loss_hist.append(loss_val.item())\n",
    "\n",
    "        # Test set\n",
    "        test_data = itertools.cycle(test_loader)\n",
    "        testdata_it, testtargets_it = next(test_data)\n",
    "        testdata_it = testdata_it.to(device)\n",
    "        testtargets_it = testtargets_it.to(device)\n",
    "\n",
    "        # Test set spike conversion\n",
    "        test_spike_data, test_spike_targets = spikegen.rate(testdata_it, testtargets_it, num_outputs=num_outputs,\n",
    "                                                            num_steps=num_steps, gain=1, offset=0, convert_targets=False,\n",
    "                                                            temporal_targets=False)\n",
    "\n",
    "        # Test set forward pass\n",
    "        test_output, test_mem_rec = net(test_spike_data.view(num_steps, batch_size, 1, 28, 28))\n",
    "\n",
    "        # Test set loss\n",
    "        log_p_ytest = log_softmax_fn(test_mem_rec)\n",
    "        log_p_ytest = log_p_ytest.sum(dim=0)\n",
    "        loss_val_test = loss_fn(log_p_ytest, test_spike_targets)\n",
    "        test_loss_hist.append(loss_val_test.item())\n",
    "\n",
    "        # Print test/train loss/accuracy\n",
    "        if counter % 50 == 0:\n",
    "            train_printer()\n",
    "        minibatch_counter += 1\n",
    "        counter += 1\n",
    "\n",
    "loss_hist_true_grad = loss_hist\n",
    "test_loss_hist_true_grad = test_loss_hist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Spiking MNIST Results\n",
    "### 8.1 Plot Training/Test Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Plot Loss\n",
    "fig = plt.figure(facecolor=\"w\", figsize=(10, 5))\n",
    "plt.plot(loss_hist)\n",
    "plt.plot(test_loss_hist)\n",
    "plt.legend([\"Train Loss\", \"Test Loss\"])\n",
    "plt.xlabel(\"Minibatch\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.2 Test Set Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "total = 0\n",
    "correct = 0\n",
    "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)\n",
    "\n",
    "with torch.no_grad():\n",
    "  net.eval()\n",
    "  for data in test_loader:\n",
    "    images, labels = data\n",
    "    images = images.to(device)\n",
    "    labels = labels.to(device)\n",
    "\n",
    "    # If current batch matches batch_size, just do the usual thing\n",
    "    if images.size()[0] == batch_size:\n",
    "      spike_test, spike_targets = spikegen.rate(images, labels, num_outputs=num_outputs, num_steps=num_steps,\n",
    "                                                            gain=1, offset=0, convert_targets=False, temporal_targets=False)\n",
    "\n",
    "      outputs, _ = net(spike_test.view(num_steps, batch_size, 1, 28, 28))\n",
    "\n",
    "    # If current batch does not match batch_size (e.g., is the final minibatch),\n",
    "    # modify batch_size in a temp variable and restore it at the end of the else block\n",
    "    else:\n",
    "      temp_bs = batch_size\n",
    "      batch_size = images.size()[0]\n",
    "      spike_test, spike_targets = spikegen.rate(images, labels, num_outputs=num_outputs, num_steps=num_steps,\n",
    "                                                            gain=1, offset=0, convert_targets=False, temporal_targets=False)\n",
    "      outputs, _ = net(spike_test.view(num_steps, images.size()[0], 1, 28, 28))\n",
    "      batch_size = temp_bs\n",
    "\n",
    "    _, predicted = outputs.sum(dim=0).max(1)\n",
    "    total += spike_targets.size(0)\n",
    "    correct += (predicted == spike_targets).sum().item()\n",
    "\n",
    "print(f\"Total correctly classified test set images: {correct}/{total}\")\n",
    "print(f\"Test Set Accuracy: {100 * correct / total}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To-do:\n",
    "* Add figures to explain in better detail\n",
    "* See if SRM0 model can reach acceptable acc within reasonable num_steps\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
